import os
import cv2
# from PIL import Image
import PIL
import numpy as np
import mindspore as ms 
from mindspore import nn
import numpy as np
import lmdb
import logging
import re
import six
from pathlib import Path
from ms_utils import CharsetMapper,onehot
import mindspore.dataset as ds
import math
from ms_transform import CVColorJitter, CVDeterioration, CVGeometry
import random
import warnings
import pandas
from mindspore import context



class ms_ImageDataset:
    """LMDB-backed scene-text recognition dataset for MindSpore.

    Samples are stored under 1-based ``image-%09d`` / ``label-%09d`` keys,
    with the total sample count under ``num-samples``.  ``__getitem__``
    returns ``(image, label, length)`` where ``image`` is the raw decoded
    ndarray (resize/augmentation is deferred to a ``dataset.map`` stage for
    throughput), ``label`` is the charset index sequence (optionally
    one-hot), and ``length`` is ``len(text) + 1`` (one slot for the end
    token, needed by losses that take gt_lengths).
    """

    def __init__(self,
                 path: str,
                 is_training: bool = True,
                 img_h: int = 32,
                 img_w: int = 128,  # originally 100; see item 25 of the project notes
                 max_length: int = 25,
                 check_length: bool = True,
                 case_sensitive: bool = False,
                 charset_path: str = 'data/charset_36.txt',
                 convert_mode: str = 'RGB',
                 data_aug: bool = True,
                 deteriorate_ratio: float = 0.,
                 multiscales: bool = True,
                 one_hot_y: bool = True,
                 return_idx: bool = False,
                 return_raw: bool = False,
                 **kwargs):
        self.path, self.name = Path(path), Path(path).name
        assert self.path.is_dir() and self.path.exists(), f"{path} is not a valid directory."
        self.convert_mode, self.check_length = convert_mode, check_length
        self.img_h, self.img_w = img_h, img_w
        self.max_length, self.one_hot_y = max_length, one_hot_y
        self.return_idx, self.return_raw = return_idx, return_raw
        self.case_sensitive, self.is_training = case_sensitive, is_training
        self.data_aug, self.multiscales = data_aug, multiscales
        # +1 reserves a slot for the end-of-sequence token.
        self.charset = CharsetMapper(charset_path, max_length=max_length + 1)
        self.c = self.charset.num_classes

        self.env = lmdb.open(str(path), readonly=True, lock=False, readahead=False, meminit=False)
        assert self.env, f'Cannot open LMDB dataset from {path}.'
        with self.env.begin(write=False) as txn:
            self.length = int(txn.get('num-samples'.encode()))

        if self.is_training and self.data_aug:
            self.augment_tfs = ds.transforms.Compose([
                CVGeometry(degrees=45, translate=(0.0, 0.0), scale=(0.5, 2.), shear=(45, 15), distortion=0.5, p=0.5),
                CVDeterioration(var=20, degrees=6, factor=4, p=0.25),
                CVColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.1, p=0.25)
            ])
        self.totensor = ds.vision.ToTensor()

    def __len__(self): return self.length

    def _next_image(self, index):
        # Fallback for an unusable sample: draw a random replacement.
        next_index = random.randint(0, len(self) - 1)
        return self.get(next_index)

    def _check_image(self, x, pixels=6):
        # Reject images whose width or height is too small to carry text.
        return x.size[0] > pixels and x.size[1] > pixels

    def resize_multiscales(self, img, borderType=cv2.BORDER_CONSTANT):
        """Resize ``img`` into the (img_h, img_w) box and pad the rest."""
        def _resize_ratio(img, ratio, fix_h=True):
            # Fit to the target box respecting ``ratio`` (h/w), then pad
            # symmetrically with the requested border mode.
            if ratio * self.img_w < self.img_h:
                if fix_h: trg_h = self.img_h
                else: trg_h = int(ratio * self.img_w)
                trg_w = self.img_w
            else: trg_h, trg_w = self.img_h, int(self.img_h / ratio)
            img = cv2.resize(img, (trg_w, trg_h))
            pad_h, pad_w = (self.img_h - trg_h) / 2, (self.img_w - trg_w) / 2
            top, bottom = math.ceil(pad_h), math.floor(pad_h)
            left, right = math.ceil(pad_w), math.floor(pad_w)
            img = cv2.copyMakeBorder(img, top, bottom, left, right, borderType)
            return img

        if self.is_training:
            # Half the time jitter the aspect ratio as augmentation.
            if random.random() < 0.5:
                base, maxh, maxw = self.img_h, self.img_h, self.img_w
                h, w = random.randint(base, maxh), random.randint(base, maxw)
                return _resize_ratio(img, h/w)
            else: return _resize_ratio(img, img.shape[0] / img.shape[1])  # keep aspect ratio
        else:  return _resize_ratio(img, img.shape[0] / img.shape[1])  # keep aspect ratio

    def resize(self, img):
        if self.multiscales: return self.resize_multiscales(img, cv2.BORDER_REPLICATE)
        else: return cv2.resize(img, (self.img_w, self.img_h))

    def get(self, idx):
        """Fetch sample ``idx`` (0-based): returns (ndarray image, str label, int idx).

        Unusable samples (bad length, tiny image, corrupted data) are
        transparently replaced via ``_next_image``.
        """
        with self.env.begin(write=False) as txn:
            image_key, label_key = f'image-{idx+1:09d}', f'label-{idx+1:09d}'
            # Pre-bind so the logging line in the except branch cannot raise
            # NameError when the label fetch itself is what failed.
            label = ''
            try:
                label = str(txn.get(label_key.encode()), 'utf-8')  # label
                label = re.sub('[^0-9a-zA-Z]+', '', label)
                if self.check_length and self.max_length > 0:
                    # Skip samples whose filtered text is empty or too long.
                    if len(label) > self.max_length or len(label) <= 0:
                        #logging.info(f'Long or short text image is found: {self.name}, {idx}, {label}, {len(label)}')
                        return self._next_image(idx)
                label = label[:self.max_length]  # length is valid; truncate defensively

                imgbuf = txn.get(image_key.encode())  # image
                buf = six.BytesIO()
                buf.write(imgbuf)
                buf.seek(0)
                with warnings.catch_warnings():
                    warnings.simplefilter("ignore", UserWarning) # EXIF warning from TiffPlugin
                    image = PIL.Image.open(buf).convert(self.convert_mode)
                if self.is_training and not self._check_image(image):
                    #logging.info(f'Invalid image is found: {self.name}, {idx}, {label}, {len(label)}')
                    return self._next_image(idx)
            except Exception:  # was a bare except: don't mask SystemExit/KeyboardInterrupt
                import traceback
                traceback.print_exc()
                logging.info(f'Corrupted image is found: {self.name}, {idx}, {label}, {len(label)}')
                return self._next_image(idx)

            # PIL.Image cannot be consumed by MindSpore directly; convert to ndarray.
            image = np.array(image)
            return image, label, idx

    def _process_training(self, image):
        if self.data_aug: image = self.augment_tfs(image)
        image = self.resize(np.array(image))
        return image

    def _process_test(self, image):
        return self.resize(np.array(image)) # TODO:move is_training to here

    def __getitem__(self, idx):
        image, text, idx_new = self.get(idx)

        # During evaluation a silently substituted sample would corrupt the
        # prediction-to-ground-truth mapping, so fail loudly.
        if not self.is_training: assert idx == idx_new, f'idx {idx} != idx_new {idx_new} during testing.'

        # Image transforms run later in a parallel dataset.map() because
        # doing them here (per the reference implementation) was too slow.

        # One extra count for the end token; float so MindSpore batches it.
        length = float(len(text) + 1)

        label = self.charset.get_labels(text, case_sensitive=self.case_sensitive)
        if self.one_hot_y: label = onehot(label, self.charset.num_classes)

        # NOTE(review): `length` is returned because the loss needs
        # gt_lengths, which rules out MindSpore's stock model.eval
        # (it expects two outputs) — a custom eval loop is required.
        return image, label, length


def data_for_train_image(image):
    """Training-time image pipeline: augment, resize/pad to 32x128, ToTensor.

    Designed to run inside ``dataset.map`` workers, so all transform
    objects are built locally per call.
    """
    target_h, target_w = 32, 128
    to_tensor = ds.vision.ToTensor()

    def _scale_and_pad(src, ratio, border, fix_h=True):
        # Fit the image to the target box respecting ``ratio`` (h/w),
        # then pad the remainder symmetrically with the given border mode.
        if ratio * target_w < target_h:
            new_h = target_h if fix_h else int(ratio * target_w)
            new_w = target_w
        else:
            new_h, new_w = target_h, int(target_h / ratio)
        src = cv2.resize(src, (new_w, new_h))
        gap_h = (target_h - new_h) / 2
        gap_w = (target_w - new_w) / 2
        src = cv2.copyMakeBorder(src,
                                 math.ceil(gap_h), math.floor(gap_h),
                                 math.ceil(gap_w), math.floor(gap_w),
                                 border)
        return src

    augment = ds.transforms.Compose([
        CVGeometry(degrees=45, translate=(0.0, 0.0), scale=(0.5, 2.), shear=(45, 15), distortion=0.5, p=0.5),
        CVDeterioration(var=20, degrees=6, factor=4, p=0.25),
        CVColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.1, p=0.25)
    ])
    image = augment(image)

    # Half the time jitter the aspect ratio; otherwise keep the native one.
    if random.random() < 0.5:
        rand_h = random.randint(target_h, target_h)  # degenerate range: always target_h
        rand_w = random.randint(target_h, target_w)
        image = _scale_and_pad(image, rand_h / rand_w, cv2.BORDER_REPLICATE)
    else:
        image = _scale_and_pad(image, image.shape[0] / image.shape[1], cv2.BORDER_REPLICATE)

    return to_tensor(image)

def data_for_train_label(label):
    """Map a text label to a one-hot charset index sequence.

    The ``CharsetMapper`` reads its charset file on construction, so it is
    built once and cached on the function object — the original rebuilt it
    (re-reading the file) for every single label in the map pipeline.
    The cache is per-process, which is safe with multiprocessing workers.
    """
    charset = getattr(data_for_train_label, '_charset', None)
    if charset is None:
        charset = CharsetMapper('data/charset_36.txt', max_length=26)
        data_for_train_label._charset = charset
    label = charset.get_labels(label, case_sensitive=False)
    label = onehot(label, charset.num_classes)
    return label





# Build the evaluation data pipeline as a module-level side effect.
# NOTE(review): the LMDB path is hard-coded to one machine — should be
# made configurable (CLI arg / config file) before reuse.
dataset_train_image_generator = ms_ImageDataset(path='/home/data4/zyh/ABINet/data/evaluation/IIIT5k_3000')
data_after_map = ds.GeneratorDataset(dataset_train_image_generator, 
                    column_names= ['image','label','length'], 
                    shuffle=False,
                    python_multiprocessing=True,num_parallel_workers=4)
#data_after_map = data_after_map.map(operations=data_for_train_label,input_columns="label",python_multiprocessing=True,num_parallel_workers=4)
# Image transforms run here in a parallel map() rather than inside
# __getitem__ — see the comments in ms_ImageDataset.__getitem__.
data_after_map = data_after_map.map(operations=data_for_train_image,input_columns="image",python_multiprocessing=True,num_parallel_workers=10)
image_dataset = data_after_map.batch(batch_size= 72, drop_remainder=True)
dataset = image_dataset        
        

# data_loader = dataset
# for batch_idx, (data, target,length) in enumerate(data_loader):
#     print(batch_idx, data.shape, target.shape,length.shape)
#     print("="*20)

